{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "51CTO课程频道：http://edu.51cto.com/lecturer/index/user_id-12330098.html<br>\n",
    "优酷频道：http://i.youku.com/sdxxqbf<br>\n",
    "微信公众号：深度学习与神经网络<br>\n",
    "Github：https://github.com/Qinbf<br>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#coding:utf-8\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import tensorflow as tf\n",
    "import pickle\n",
    "import math\n",
    "from six.moves import xrange"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                     0                                                  1  \\\n",
      "0  6215603645409872328  c924,c531,c102,c284,c188,c104,c98,c107,c11,c11...   \n",
      "1  6649324930261961840  c346,c1549,c413,c294,c675,c504,c183,c74,c541,c...   \n",
      "2 -4251899610700378615  c96,c97,c97,c98,c99,c100,c101,c141,c42,c42,c10...   \n",
      "3  6213817087034420233  c504,c157,c221,c221,c633,c468,c469,c1637,c1072...   \n",
      "4 -8930652370334418373  c0,c310,c35,c122,c123,c11,c317,c91,c175,c476,c...   \n",
      "\n",
      "                                                   2  \\\n",
      "0  w1340,w1341,w55,w1344,w58,w6,w24178,w26959,w47...   \n",
      "1  w40132,w1357,w1556,w1380,w2464,w33,w16791,w109...   \n",
      "2  w53,w54,w1779,w54,w1309,w54,w369,w949,w65587,w...   \n",
      "3  w5083,w12537,w10427,w29724,w6,w2566,w11,w18476...   \n",
      "4  w33792,w21,w83,w6,w21542,w21,w140670,w25,w1110...   \n",
      "\n",
      "                                                   3  \\\n",
      "0  c1128,c529,c636,c572,c1321,c139,c540,c223,c510...   \n",
      "1                                                NaN   \n",
      "2  c149,c148,c148,c42,c185,c95,c95,c186,c186,c186...   \n",
      "3  c15,c131,c39,c40,c85,c166,c969,c2456,c17,c636,...   \n",
      "4                                                NaN   \n",
      "\n",
      "                                                   4  \n",
      "0  w4094,w1618,w20104,w19234,w1097,w1005,w4228,w2...  \n",
      "1                                                NaN  \n",
      "2                                                NaN  \n",
      "3  w2550,w24,w239,w98,w19456,w11,w108710,w3483,w2...  \n",
      "4                                                NaN  \n"
     ]
    }
   ],
   "source": [
    "# Load the evaluation question set: a tab-separated text file without a header row.\n",
    "# Judging by the preview printed below, column 0 is a numeric question id and\n",
    "# columns 1-4 hold comma-separated 'c...' / 'w...' tokens; some fields are NaN.\n",
    "eval_set_path = './ieee_zhihu_cup/question_eval_set.txt'\n",
    "reader = pd.read_csv(eval_set_path, sep='\\t', header=None)\n",
    "# Preview the first five rows.\n",
    "print(reader.head(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('max_document_length:', 76)\n"
     ]
    }
   ],
   "source": [
    "# Find the largest number of comma-separated tokens in any row of column 2.\n",
    "x_text = reader.iloc[:,2]\n",
    "max_document_length = 0\n",
    "# Missing fields are read as NaN, which has no .split(); skip them explicitly\n",
    "# with dropna() instead of hiding every possible error behind a bare except.\n",
    "for line in x_text.dropna():\n",
    "    max_document_length = max(max_document_length, len(line.split(',')))\n",
    "\n",
    "print(\"max_document_length:\",max_document_length)\n",
    "\n",
    "# Restore the previously saved vocabulary object from the file \"vocab_dict\".\n",
    "vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor.restore(\"vocab_dict\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Split every row of x_text on ',' into a list of tokens.\n",
    "# A missing row (NaN/None) becomes the one-token list [' '], so every entry of\n",
    "# `text` is a list and the row order of x_text is preserved. The missing-value\n",
    "# check is explicit; any other unexpected value now raises instead of being\n",
    "# silently replaced.\n",
    "text = [line.split(',') if pd.notnull(line) else [' '] for line in x_text]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 217360/217360 [00:05<00:00, 40820.07it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[  4507,   2664,    423,   3387,    425,     10,  84669,   1744,\n",
       "           152,     13,     90,    152,   1556,    403,  17192,     10,\n",
       "          3686,     13,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0],\n",
       "       [ 18531,    861,   1538,    490,  16758,    197,   4225,    658,\n",
       "         18551,     10,   4100,     15,   1929,     52,     13,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0],\n",
       "       [  1207,     19,    810,     19, 126081,     19,    501,   2249,\n",
       "         85078,     35,    218,    308,     99,    105,    313,     13,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0],\n",
       "       [  1040,  11856,    360,  23102,     10,   4100,      4,    432,\n",
       "            17,   1424,      0,     13,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0],\n",
       "       [  3538,    137,   1628,     10,   8450,    137,      0,     16,\n",
       "            17,     13,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0,      0,      0,      0,      0,\n",
       "             0,      0,      0,      0]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Convert every token to its vocabulary id. Each row is right-padded with 0 up\n",
    "# to max_document_length columns; tokens beyond that length are ignored.\n",
    "# NOTE(review): this relies on vocabulary_.get() returning 0 for unknown tokens\n",
    "# rather than raising (the documented behaviour of a frozen vocabulary), which\n",
    "# is why the former bare try/except around the lookup was dropped.\n",
    "token_to_id = vocab_processor.vocabulary_.get\n",
    "# Allocate the padded matrix once instead of growing nested Python lists.\n",
    "x = np.zeros((len(text), max_document_length), dtype=int)\n",
    "for row, line in enumerate(tqdm(text)):\n",
    "    ids = [token_to_id(token) for token in line[:max_document_length]]\n",
    "    x[row, :len(ids)] = ids\n",
    "x[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def batch_iter(data, batch_size, num_epochs, shuffle=False):\n",
    "    \"\"\"Yield consecutive slices of `data`, at most `batch_size` samples each.\n",
    "\n",
    "    :param data: sequence of samples; it is converted with np.array() first.\n",
    "    :param batch_size: maximum number of samples per batch (the last batch of\n",
    "        an epoch may be shorter).\n",
    "    :param num_epochs: number of full passes over `data`.\n",
    "    :param shuffle: if True, draw a new random sample order for every epoch.\n",
    "    :return: generator of numpy arrays; nothing is yielded for empty `data`.\n",
    "    \"\"\"\n",
    "    data = np.array(data)\n",
    "    data_size = len(data)\n",
    "    # Ceiling division in pure integer arithmetic. The previous\n",
    "    # int((n - 1) / batch_size) + 1 produced one empty batch for empty data\n",
    "    # under true division (Python 3) but none under floor division (Python 2).\n",
    "    num_batches_per_epoch = (data_size + batch_size - 1) // batch_size\n",
    "    print(\"num_batches_per_epoch:\",num_batches_per_epoch)\n",
    "    for epoch in range(num_epochs):\n",
    "        if shuffle:\n",
    "            # New permutation of the sample order for every epoch.\n",
    "            shuffle_indices = np.random.permutation(np.arange(data_size))\n",
    "            shuffled_data = data[shuffle_indices]\n",
    "        else:\n",
    "            shuffled_data = data\n",
    "        for start_index in range(0, data_size, batch_size):\n",
    "            # Slicing clips at data_size, so the final batch may be shorter.\n",
    "            yield shuffled_data[start_index:start_index + batch_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def eval(predict_label_and_marked_label_list):\n",
    "    \"\"\"Score ranked top-5 label predictions against the annotated labels.\n",
    "\n",
    "    NOTE: the name shadows the built-in eval(); it is kept so existing callers\n",
    "    keep working.\n",
    "\n",
    "    :param predict_label_and_marked_label_list: list of\n",
    "        (predict_labels, marked_labels) tuples, for example\n",
    "        [ ([1, 2, 3, 4, 5], [4, 5, 6, 7]),\n",
    "          ([3, 2, 1, 4, 7], [5, 7, 3]) ]\n",
    "        predict_labels is ordered and is expected to be de-duplicated by the\n",
    "        caller (e.g. [1,2,3,2,4,1,6] -> [1,2,3,4,6]); only its first 5 entries\n",
    "        are scored. marked_labels is unordered. For the example above the hit\n",
    "        patterns are\n",
    "        [0, 0, 0, 1, 1]   (4 and 5 hit)\n",
    "        [1, 0, 0, 0, 1]   (3 and 7 hit)\n",
    "    :return: 2 * P * R / (P + R). P sums, over positions 0-4, the fraction of\n",
    "        samples with a hit at that position divided by ln(position + 2); R is\n",
    "        the total number of hits divided by the total number of distinct\n",
    "        marked labels. Returns 0.0 when nothing was hit, which includes an\n",
    "        empty input list.\n",
    "    \"\"\"\n",
    "    right_label_num = 0  # total number of hit labels\n",
    "    right_label_at_pos_num = [0, 0, 0, 0, 0]  # hits at each of the 5 positions\n",
    "    sample_num = 0   # number of samples\n",
    "    all_marked_label_num = 0    # total number of distinct marked labels\n",
    "    for predict_labels, marked_labels in predict_label_and_marked_label_list:\n",
    "        sample_num += 1\n",
    "        marked_label_set = set(marked_labels)\n",
    "        all_marked_label_num += len(marked_label_set)\n",
    "        # zip stops after 5 positions or when the predictions run out.\n",
    "        for pos, label in zip(range(5), predict_labels):\n",
    "            if label in marked_label_set:     # hit\n",
    "                right_label_num += 1\n",
    "                right_label_at_pos_num[pos] += 1\n",
    "\n",
    "    # Without any hit, precision and recall are both 0 (or undefined when the\n",
    "    # input is empty); return 0.0 instead of dividing by zero below.\n",
    "    if right_label_num == 0:\n",
    "        return 0.0\n",
    "\n",
    "    precision = 0.0\n",
    "    for pos, right_num in enumerate(right_label_at_pos_num):\n",
    "        # index 0-4 corresponds to rank 1-5; the log argument is rank + 1, hence pos + 2\n",
    "        precision += (right_num / float(sample_num)) / math.log(2.0 + pos)\n",
    "    recall = float(right_label_num) / all_marked_label_num\n",
    "\n",
    "    return 2*(precision * recall) / (precision + recall )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('num_filters_total:', 3072)\n"
     ]
    }
   ],
   "source": [
    "# Build the CNN graph that maps padded token ids to per-class scores.\n",
    "# The name= strings of all variables are left exactly as they were, because the\n",
    "# checkpoint restored in a later cell is presumably keyed by them.\n",
    "\n",
    "# sequence_length - number of token ids in each padded row of x\n",
    "sequence_length = x.shape[1]\n",
    "\n",
    "# Two placeholders are fed at run time: the token-id batch and the dropout\n",
    "# keep probability.\n",
    "input_x = tf.placeholder(tf.int32, [None, sequence_length], name=\"input_x\")\n",
    "dropout_keep_prob = tf.placeholder(tf.float32, name=\"dropout_keep_prob\")\n",
    "\n",
    "# num_classes - number of output classes\n",
    "num_classes = 1999\n",
    "# vocab_size - number of entries in the restored vocabulary\n",
    "vocab_size = len(vocab_processor.vocabulary_)\n",
    "# embedding_size - length of each embedding vector\n",
    "embedding_size = 256\n",
    "# filter_sizes - convolution window heights (in tokens)\n",
    "filter_sizes = [3, 4, 5]\n",
    "# num_filters - number of filters per window height\n",
    "num_filters = 1024\n",
    "\n",
    "# Embedding table, looked up by token id: [None, sequence_length, embedding_size]\n",
    "Weights = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0), name=\"Weights\")\n",
    "embedded_chars = tf.nn.embedding_lookup(Weights, input_x)\n",
    "# conv2d needs a channel axis: [None, sequence_length, embedding_size, 1]\n",
    "embedded_chars_expanded = tf.expand_dims(embedded_chars, -1)\n",
    "\n",
    "# One convolution + max-pool branch per window height.\n",
    "pooled_outputs = []\n",
    "for filter_size in filter_sizes:\n",
    "    with tf.name_scope(\"conv-maxpool-%s\" % filter_size):\n",
    "        # The filter spans the full embedding width, so the convolution slides\n",
    "        # along the token axis only.\n",
    "        conv_kernel = tf.Variable(\n",
    "            tf.truncated_normal([filter_size, embedding_size, 1, num_filters], stddev=0.1),\n",
    "            name=\"W\")\n",
    "        conv_bias = tf.Variable(tf.constant(0.1, shape=[num_filters]), name=\"b\")\n",
    "        conv = tf.nn.conv2d(\n",
    "            embedded_chars_expanded,\n",
    "            conv_kernel,\n",
    "            strides=[1, 1, 1, 1],\n",
    "            padding=\"VALID\",\n",
    "            name=\"conv\")\n",
    "        activated = tf.nn.relu(tf.nn.bias_add(conv, conv_bias), name=\"relu\")\n",
    "        # The pooling window covers every convolution position, leaving one\n",
    "        # value per filter.\n",
    "        pooled = tf.nn.max_pool(\n",
    "            activated,\n",
    "            ksize=[1, sequence_length - filter_size + 1, 1, 1],\n",
    "            strides=[1, 1, 1, 1],\n",
    "            padding='VALID',\n",
    "            name=\"pool\")\n",
    "        pooled_outputs.append(pooled)\n",
    "\n",
    "# Concatenate the branches along the filter axis and flatten to 2-D.\n",
    "num_filters_total = num_filters * len(filter_sizes)\n",
    "print(\"num_filters_total:\", num_filters_total)\n",
    "h_pool = tf.concat(pooled_outputs, 3)\n",
    "h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total])\n",
    "\n",
    "# Dropout on the pooled features\n",
    "with tf.name_scope(\"dropout\"):\n",
    "    h_drop = tf.nn.dropout(h_pool_flat, dropout_keep_prob)\n",
    "\n",
    "# Final linear layer producing unnormalized class scores\n",
    "with tf.name_scope(\"output\"):\n",
    "    output_weights = tf.get_variable(\n",
    "        \"W\",\n",
    "        shape=[num_filters_total, num_classes],\n",
    "        initializer=tf.contrib.layers.xavier_initializer())\n",
    "    output_bias = tf.Variable(tf.constant(0.1, shape=[num_classes]), name=\"b\")\n",
    "    scores = tf.nn.xw_plus_b(h_drop, output_weights, output_bias, name=\"scores\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ./models/model_-7200\n",
      "('num_batches_per_epoch:', 218)\n",
      "('Evaluation:step', 5)\n",
      "('Evaluation:step', 10)\n",
      "('Evaluation:step', 15)\n",
      "('Evaluation:step', 20)\n",
      "('Evaluation:step', 25)\n",
      "('Evaluation:step', 30)\n",
      "('Evaluation:step', 35)\n",
      "('Evaluation:step', 40)\n",
      "('Evaluation:step', 45)\n",
      "('Evaluation:step', 50)\n",
      "('Evaluation:step', 55)\n",
      "('Evaluation:step', 60)\n",
      "('Evaluation:step', 65)\n",
      "('Evaluation:step', 70)\n",
      "('Evaluation:step', 75)\n",
      "('Evaluation:step', 80)\n",
      "('Evaluation:step', 85)\n",
      "('Evaluation:step', 90)\n",
      "('Evaluation:step', 95)\n",
      "('Evaluation:step', 100)\n",
      "('Evaluation:step', 105)\n",
      "('Evaluation:step', 110)\n",
      "('Evaluation:step', 115)\n",
      "('Evaluation:step', 120)\n",
      "('Evaluation:step', 125)\n",
      "('Evaluation:step', 130)\n",
      "('Evaluation:step', 135)\n",
      "('Evaluation:step', 140)\n",
      "('Evaluation:step', 145)\n",
      "('Evaluation:step', 150)\n",
      "('Evaluation:step', 155)\n",
      "('Evaluation:step', 160)\n",
      "('Evaluation:step', 165)\n",
      "('Evaluation:step', 170)\n",
      "('Evaluation:step', 175)\n",
      "('Evaluation:step', 180)\n",
      "('Evaluation:step', 185)\n",
      "('Evaluation:step', 190)\n",
      "('Evaluation:step', 195)\n",
      "('Evaluation:step', 200)\n",
      "('Evaluation:step', 205)\n",
      "('Evaluation:step', 210)\n",
      "('Evaluation:step', 215)\n"
     ]
    }
   ],
   "source": [
    "# Checkpoint to load the trained weights from\n",
    "checkpoint_file = \"./models/model-10000\"\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    # top_k returns (values, indices); the indices are the predicted class ids.\n",
    "    predict_top_5 = tf.nn.top_k(scores, k=5)\n",
    "    # restore() assigns every variable handled by the Saver, so running a\n",
    "    # separate variable initializer first would only be overwritten.\n",
    "    saver = tf.train.Saver()\n",
    "    saver.restore(sess, checkpoint_file)\n",
    "\n",
    "    # One ordered pass over x in batches of 1000 rows\n",
    "    batches = batch_iter(x, 1000, 1)\n",
    "\n",
    "    predicted_batches = []\n",
    "    for i, x_batch in enumerate(batches, 1):\n",
    "        # A keep probability of 1.0 disables dropout for inference.\n",
    "        predict_5 = sess.run(predict_top_5,feed_dict={input_x:x_batch,dropout_keep_prob:1.0})\n",
    "        predicted_batches.append(predict_5[1])\n",
    "        if (i%5==0):\n",
    "            print (\"Evaluation:step\",i)\n",
    "\n",
    "    # Join the per-batch results once instead of re-copying a growing array on\n",
    "    # every step; fall back to an empty (0, 5) array when x has no rows.\n",
    "    if predicted_batches:\n",
    "        predict = np.concatenate(predicted_batches)\n",
    "    else:\n",
    "        predict = np.empty((0, 5), dtype=np.int32)\n",
    "    np.savetxt(\"predict.txt\",predict,fmt='%d')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
